// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vulnerability

import (
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/lib/pq"
	log "github.com/sirupsen/logrus"

	"github.com/coreos/clair/database"
	"github.com/coreos/clair/database/pgsql/feature"
	"github.com/coreos/clair/database/pgsql/monitoring"
	"github.com/coreos/clair/database/pgsql/page"
	"github.com/coreos/clair/database/pgsql/util"
	"github.com/coreos/clair/ext/versionfmt"
	"github.com/coreos/clair/pkg/pagination"
)

const (
	// searchVulnerability looks up a non-deleted vulnerability by its name ($1)
	// and namespace name ($2). It returns the vulnerability ID, description,
	// link, severity, metadata and the namespace's version format.
	searchVulnerability = `
		SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format 
		FROM vulnerability AS v, namespace AS n
		WHERE v.namespace_id = n.id
		AND v.name = $1
		AND n.name = $2
		AND v.deleted_at IS NULL
		`

	// searchVulnerabilityByID loads a vulnerability and its namespace by
	// vulnerability row ID ($1). Unlike searchVulnerability, it does not filter
	// on deleted_at, so soft-deleted vulnerabilities are returned as well.
	searchVulnerabilityByID = `
		SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format
		FROM vulnerability AS v, namespace AS n
		WHERE v.namespace_id = n.id
			AND v.id = $1`

	// insertVulnerability inserts a vulnerability ($1 name, $2 description,
	// $3 link, $4 severity, $5 metadata) into the namespace identified by name
	// $6 and version format $7, and returns the new row ID.
	// NOTE(review): if no such namespace exists, the subquery yields NULL for
	// namespace_id; presumably a NOT NULL constraint rejects it. Confirm
	// against the schema.
	insertVulnerability = `
		WITH ns AS (
			SELECT id FROM namespace WHERE name = $6 AND version_format = $7
		)
		INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at)
		VALUES((SELECT id FROM ns), $1, $2, $3, $4, $5, CURRENT_TIMESTAMP)
		RETURNING id`

	// removeVulnerability soft-deletes the live (deleted_at IS NULL)
	// vulnerability named $2 in the namespace named $1 by stamping deleted_at,
	// and returns its row ID.
	// NOTE(review): the namespace is matched by name only. If several
	// namespaces could share a name (e.g. with different version formats),
	// PostgreSQL would reject the scalar subquery for returning more than one
	// row. Confirm that namespace names are unique.
	removeVulnerability = `
		UPDATE Vulnerability
		SET deleted_at = CURRENT_TIMESTAMP
		WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1)
			AND name = $2
			AND deleted_at IS NULL
		RETURNING id`

	// searchNotificationVulnerableAncestry lists the distinct ancestries
	// (ID, name) whose layers contain a namespaced feature cached as affected
	// by vulnerability $1. Results start at ancestry ID $2, are returned in
	// ascending ID order, and are capped at $3 rows.
	searchNotificationVulnerableAncestry = `
		SELECT DISTINCT ON (a.id)
			 a.id, a.name
		 FROM vulnerability_affected_namespaced_feature AS vanf,
			 ancestry_layer AS al, ancestry_feature AS af, ancestry AS a
		 WHERE vanf.vulnerability_id = $1
			 AND a.id >= $2
			 AND al.ancestry_id = a.id
			 AND al.id = af.ancestry_layer_id
			 AND af.namespaced_feature_id = vanf.namespaced_feature_id
		 ORDER BY a.id ASC
		 LIMIT $3;`
)

// queryInvalidateVulnerabilityCache builds a DELETE statement that removes the
// vulnerability_affected_feature rows of count vulnerabilities. The IDs are
// supplied as positional query arguments, using the placeholder list produced
// by util.QueryString(1, count).
func queryInvalidateVulnerabilityCache(count int) string {
	placeholders := util.QueryString(1, count)
	return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature 
		WHERE vulnerability_id IN (%s)`, placeholders)
}

// querySearchLastDeletedVulnerabilityID builds a query that takes count
// (vulnerability name, namespace name) pairs as query arguments. For each pair,
// it returns (id, name, namespace name) of the most recently deleted matching
// vulnerability. Pairs with no deleted vulnerability produce no row.
//
// NOTE(Sida): Every search query must keep count below the limit imposed by
// PostgreSQL's stack depth: IN is resolved into nested ORs, and the parser
// might exceed the stack depth. TODO(Sida): Generate different queries for
// different counts: if count < 5120, use IN; for 5120 <= count < 65536, use a
// temporary table; for count >= 65536, the user is expected to split data into
// batches.
func querySearchLastDeletedVulnerabilityID(count int) string {
	pairs := util.QueryString(2, count)
	return fmt.Sprintf(`
			SELECT vid, vname, nname FROM (
				SELECT v.id AS vid, v.name AS vname, n.name AS nname, 
				row_number() OVER (
					PARTITION by (v.name, n.name) 
					ORDER BY v.deleted_at DESC
					) AS rownum 
				FROM vulnerability AS v, namespace AS n 
				WHERE v.namespace_id = n.id 
					AND (v.name, n.name) IN ( %s )
					AND v.deleted_at IS NOT NULL
				) tmp WHERE rownum <= 1`, pairs)
}

// querySearchNotDeletedVulnerabilityID builds a query that takes count
// (vulnerability name, namespace name) pairs as query arguments and returns
// (id, name, namespace name) for every non-deleted vulnerability matching one
// of the pairs.
func querySearchNotDeletedVulnerabilityID(count int) string {
	pairs := util.QueryString(2, count)
	return fmt.Sprintf(`
		SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n
		WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s) 
		AND v.deleted_at IS NULL`, pairs)
}

type (
	// affectedAncestry is one (ID, name) row returned by
	// searchNotificationVulnerableAncestry.
	affectedAncestry struct {
		name string
		id   int64
	}

	// affectRelation records that a vulnerability affects a namespaced
	// feature. addedBy is the row ID of the vulnerability affected feature
	// whose version range matched.
	affectRelation struct {
		vulnerabilityID     int64
		namespacedFeatureID int64
		addedBy             int64
	}

	// affectedFeatureRows holds the affected features inserted for a single
	// vulnerability, keyed by their affected-feature row ID.
	affectedFeatureRows struct {
		rows map[int64]database.AffectedFeature
	}
)

// FindVulnerabilities retrieves the non-deleted vulnerabilities identified by
// the given (name, namespace) keys, together with their affected features.
//
// The returned slice has the same length and order as vulnerabilities. An
// entry whose vulnerability is not found has Valid set to false and carries
// only the requested name and namespace. Every affected feature is assigned
// the namespace of the vulnerability it belongs to.
func FindVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
	defer monitoring.ObserveQueryTime("findVulnerabilities", "", time.Now())
	resultVuln := make([]database.NullableVulnerability, len(vulnerabilities))
	// Vulnerability row ID -> result entries for that vulnerability (the same
	// key may be requested more than once). Pointers into resultVuln stay
	// valid because the slice is never resized.
	vulnIDMap := map[int64][]*database.NullableVulnerability{}

	//TODO(Sida): Change to bulk search.
	stmt, err := tx.Prepare(searchVulnerability)
	if err != nil {
		return nil, util.HandleError("searchVulnerability", err)
	}

	// load vulnerabilities
	for i, key := range vulnerabilities {
		var (
			id   sql.NullInt64
			vuln = database.NullableVulnerability{
				VulnerabilityWithAffected: database.VulnerabilityWithAffected{
					Vulnerability: database.Vulnerability{
						Name: key.Name,
						Namespace: database.Namespace{
							Name: key.Namespace,
						},
					},
				},
			}
		)

		err := stmt.QueryRow(key.Name, key.Namespace).Scan(
			&id,
			&vuln.Description,
			&vuln.Link,
			&vuln.Severity,
			&vuln.Metadata,
			&vuln.Namespace.VersionFormat,
		)

		// sql.ErrNoRows only means this key does not exist; that is reported
		// through Valid rather than as an error.
		if err != nil && err != sql.ErrNoRows {
			stmt.Close()
			return nil, util.HandleError("searchVulnerability", err)
		}
		vuln.Valid = id.Valid
		resultVuln[i] = vuln
		if id.Valid {
			vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i])
		}
	}

	if err := stmt.Close(); err != nil {
		return nil, util.HandleError("searchVulnerability", err)
	}

	// No vulnerability was found, so there are no affected features to load.
	if len(vulnIDMap) == 0 {
		return resultVuln, nil
	}

	toQuery := make([]int64, 0, len(vulnIDMap))
	for id := range vulnIDMap {
		toQuery = append(toQuery, id)
	}

	// load vulnerability affected features
	rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery))
	if err != nil {
		return nil, util.HandleError("searchVulnerabilityAffected", err)
	}
	// Release the rows on every path, including early returns on Scan errors.
	defer rows.Close()

	for rows.Next() {
		var (
			id int64
			f  database.AffectedFeature
		)

		if err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion); err != nil {
			return nil, util.HandleError("searchVulnerabilityAffected", err)
		}

		for _, vuln := range vulnIDMap[id] {
			f.Namespace = vuln.Namespace
			vuln.Affected = append(vuln.Affected, f)
		}
	}

	// A failure during iteration ends rows.Next early. Do not mistake it for
	// a complete result.
	if err := rows.Err(); err != nil {
		return nil, util.HandleError("searchVulnerabilityAffected", err)
	}

	return resultVuln, nil
}

// InsertVulnerabilities stores the given vulnerabilities within tx in three
// steps: the vulnerability rows, then their affected features, and finally
// the vulnerability/namespaced-feature cache entries derived from those
// affected features. It stops at the first failing step.
func InsertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) error {
	defer monitoring.ObserveQueryTime("insertVulnerabilities", "all", time.Now())

	var (
		ids    []int64
		byVuln map[int64]affectedFeatureRows
		err    error
	)

	if ids, err = insertVulnerabilities(tx, vulnerabilities); err != nil {
		return err
	}

	if byVuln, err = InsertVulnerabilityAffected(tx, ids, vulnerabilities); err != nil {
		return err
	}

	return CacheVulnerabiltyAffectedNamespacedFeature(tx, byVuln)
}

// InsertVulnerabilityAffected inserts the affected features of each provided
// vulnerability. It returns, keyed by vulnerability row ID, the inserted
// affected-feature rows keyed by their own row ID.
//
// vulnerabilityIDs[i] must be the row ID of vulnerabilities[i]. An error is
// returned before anything is inserted if vulnerabilityIDs is shorter than
// vulnerabilities. An error is also returned if an affected feature's type is
// absent from the map returned by feature.GetFeatureTypeMap.
func InsertVulnerabilityAffected(tx *sql.Tx, vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) {
	// Indexing vulnerabilityIDs below would otherwise panic.
	if len(vulnerabilityIDs) < len(vulnerabilities) {
		return nil, fmt.Errorf("expected at least %d vulnerability IDs, got %d", len(vulnerabilities), len(vulnerabilityIDs))
	}

	types, err := feature.GetFeatureTypeMap(tx)
	if err != nil {
		return nil, err
	}

	stmt, err := tx.Prepare(insertVulnerabilityAffected)
	if err != nil {
		return nil, util.HandleError("insertVulnerabilityAffected", err)
	}
	defer stmt.Close()

	vulnFeature := make(map[int64]affectedFeatureRows, len(vulnerabilities))
	for i, vuln := range vulnerabilities {
		vulnID := vulnerabilityIDs[i]
		// affected feature row ID -> affected feature
		affectedFeatures := make(map[int64]database.AffectedFeature, len(vuln.Affected))
		for _, f := range vuln.Affected {
			// A missing map entry would otherwise be stored as the zero type ID.
			typeID, ok := types.ByName[f.FeatureType]
			if !ok {
				return nil, fmt.Errorf("unknown feature type %v for affected feature %q", f.FeatureType, f.FeatureName)
			}

			var affectedID int64
			if err := stmt.QueryRow(vulnID, f.FeatureName, f.AffectedVersion, typeID, f.FixedInVersion).Scan(&affectedID); err != nil {
				return nil, util.HandleError("insertVulnerabilityAffected", err)
			}
			affectedFeatures[affectedID] = f
		}
		vulnFeature[vulnID] = affectedFeatureRows{rows: affectedFeatures}
	}

	return vulnFeature, nil
}

// insertVulnerabilities stores each vulnerability row within tx and returns
// the new row IDs in input order. It assumes every vulnerability is valid.
// The whole batch is rejected up front if any (name, namespace name) pair
// appears more than once.
func insertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) {
	seen := make(map[database.VulnerabilityID]struct{}, len(vulnerabilities))
	for i := range vulnerabilities {
		key := database.VulnerabilityID{
			Name:      vulnerabilities[i].Name,
			Namespace: vulnerabilities[i].Namespace.Name,
		}
		if _, dup := seen[key]; dup {
			return nil, errors.New("inserting duplicated vulnerabilities is not allowed")
		}
		seen[key] = struct{}{}
	}

	//TODO(Sida): Change to bulk insert.
	stmt, err := tx.Prepare(insertVulnerability)
	if err != nil {
		return nil, util.HandleError("insertVulnerability", err)
	}
	defer stmt.Close()

	ids := make([]int64, len(vulnerabilities))
	for i := range vulnerabilities {
		v := &vulnerabilities[i]
		row := stmt.QueryRow(
			v.Name,
			v.Description,
			v.Link,
			&v.Severity,
			&v.Metadata,
			v.Namespace.Name,
			v.Namespace.VersionFormat,
		)
		if err := row.Scan(&ids[i]); err != nil {
			return nil, util.HandleError("insertVulnerability", err)
		}
	}

	return ids, nil
}

// LockFeatureVulnerabilityCache executes the lockVulnerabilityAffects
// statement in tx. Callers in this file use it to prevent
// InsertNamespacedFeatures from modifying the vulnerability affected cache
// while they change it. Presumably this takes a lock held until tx ends;
// confirm against lockVulnerabilityAffects.
func LockFeatureVulnerabilityCache(tx *sql.Tx) error {
	if _, err := tx.Exec(lockVulnerabilityAffects); err != nil {
		return util.HandleError("lockVulnerabilityAffects", err)
	}
	return nil
}

// CacheVulnerabiltyAffectedNamespacedFeature caches which namespaced features
// are affected by the given vulnerabilities.
//
// affected maps a vulnerability row ID to its inserted affected-feature rows.
// The function first locks the cache, then fetches the candidate namespaced
// features with searchVulnerabilityPotentialAffected. Each candidate is kept
// only if versionfmt.InRange reports its version inside the affected version
// of the affected-feature row that produced it. Each kept relation is
// inserted; an insert that affects no row is an error.
func CacheVulnerabiltyAffectedNamespacedFeature(tx *sql.Tx, affected map[int64]affectedFeatureRows) error {
	// Prevent InsertNamespacedFeatures from modifying the cache meanwhile.
	if err := LockFeatureVulnerabilityCache(tx); err != nil {
		return err
	}

	vulnIDs := make([]int64, 0, len(affected))
	for id := range affected {
		vulnIDs = append(vulnIDs, id)
	}

	rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs))
	if err != nil {
		return util.HandleError("searchVulnerabilityPotentialAffected", err)
	}
	defer rows.Close()

	var relation []affectRelation
	for rows.Next() {
		var (
			vulnID   int64
			nsfID    int64
			fVersion string
			addedBy  int64
		)

		if err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy); err != nil {
			return util.HandleError("searchVulnerabilityPotentialAffected", err)
		}

		// addedBy is the affected-feature row that made nsfID a candidate.
		candidate, ok := affected[vulnID].rows[addedBy]
		if !ok {
			return errors.New("vulnerability affected feature not found")
		}

		in, err := versionfmt.InRange(candidate.Namespace.VersionFormat, fVersion, candidate.AffectedVersion)
		if err != nil {
			return err
		}
		if !in {
			continue
		}

		relation = append(relation, affectRelation{
			vulnerabilityID:     vulnID,
			namespacedFeatureID: nsfID,
			addedBy:             addedBy,
		})
	}

	// An iteration failure must not be mistaken for "no more candidates";
	// otherwise only a partial cache would be written.
	if err := rows.Err(); err != nil {
		return util.HandleError("searchVulnerabilityPotentialAffected", err)
	}

	//TODO(Sida): Change to bulk insert.
	for _, r := range relation {
		result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy)
		if err != nil {
			return util.HandleError("insertVulnerabilityAffectedNamespacedFeature", err)
		}

		num, err := result.RowsAffected()
		if err != nil {
			return err
		}
		if num <= 0 {
			return errors.New("Nothing cached in database")
		}
	}

	log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation))
	return nil
}

// DeleteVulnerabilities soft-deletes the given vulnerabilities (see
// MarkVulnerabilitiesAsDeleted). It then removes their affected-feature rows
// through InvalidateVulnerabilityCache, stopping at the first error.
func DeleteVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) error {
	defer monitoring.ObserveQueryTime("DeleteVulnerability", "all", time.Now())

	deleted, err := MarkVulnerabilitiesAsDeleted(tx, vulnerabilities)
	if err != nil {
		return err
	}
	return InvalidateVulnerabilityCache(tx, deleted)
}

// InvalidateVulnerabilityCache deletes the vulnerability_affected_feature rows
// of the given vulnerability row IDs. An empty list is a no-op. Otherwise, the
// cache is locked first so InsertNamespacedFeatures cannot modify it meanwhile.
func InvalidateVulnerabilityCache(tx *sql.Tx, vulnerabilityIDs []int64) error {
	n := len(vulnerabilityIDs)
	if n == 0 {
		return nil
	}

	if err := LockFeatureVulnerabilityCache(tx); err != nil {
		return err
	}

	//TODO(Sida): Make a nicer interface for bulk inserting.
	args := make([]interface{}, 0, n)
	for _, id := range vulnerabilityIDs {
		args = append(args, id)
	}

	if _, err := tx.Exec(queryInvalidateVulnerabilityCache(n), args...); err != nil {
		return util.HandleError("removeVulnerabilityAffectedFeature", err)
	}
	return nil
}

// MarkVulnerabilitiesAsDeleted soft-deletes each given vulnerability by
// stamping deleted_at on its live row. It returns the affected row IDs in
// input order. It fails on the first vulnerability without a live row: the
// Scan then returns sql.ErrNoRows, which is passed to util.HandleError.
func MarkVulnerabilitiesAsDeleted(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]int64, error) {
	stmt, err := tx.Prepare(removeVulnerability)
	if err != nil {
		return nil, util.HandleError("removeVulnerability", err)
	}
	defer stmt.Close()

	var deleted []int64
	for _, v := range vulnerabilities {
		var id sql.NullInt64
		if err := stmt.QueryRow(v.Namespace, v.Name).Scan(&id); err != nil {
			return nil, util.HandleError("removeVulnerability", err)
		}
		// Defensive: a row was returned, so this only fires on a NULL id.
		if !id.Valid {
			return nil, util.HandleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database"))
		}
		deleted = append(deleted, id.Int64)
	}
	return deleted, nil
}

// FindLatestDeletedVulnerabilityIDs returns, for each element of vulnIDs, the
// row ID of the most recently deleted vulnerability with that name and
// namespace name.
//
// The result has the same length and order as vulnIDs. An entry is invalid
// (Valid is false) when no deleted vulnerability matches. Empty input yields
// a nil result.
func FindLatestDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
	return FindVulnerabilityIDs(tx, vulnIDs, true)
}

// FindNotDeletedVulnerabilityIDs returns, for each element of vulnIDs, the row
// ID of the non-deleted vulnerability with that name and namespace name.
//
// The result has the same length and order as vulnIDs. An entry is invalid
// (Valid is false) when no non-deleted vulnerability matches. Empty input
// yields a nil result.
func FindNotDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
	return FindVulnerabilityIDs(tx, vulnIDs, false)
}

// FindVulnerabilityIDs looks up vulnerability row IDs by (name, namespace
// name).
//
// When withLatestDeleted is true, it returns the most recently deleted row for
// each key; otherwise it returns the non-deleted row. The output has the same
// length and order as vulnIDs. Keys without a matching row get an invalid
// sql.NullInt64. Empty input returns nil without querying.
func FindVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) {
	if len(vulnIDs) == 0 {
		return nil, nil
	}

	// Query arguments are flattened (name, namespace) pairs, matching the
	// tuple IN list built by the query helpers.
	keys := make([]interface{}, 0, len(vulnIDs)*2)
	for _, vulnID := range vulnIDs {
		keys = append(keys, vulnID.Name, vulnID.Namespace)
	}

	var query string
	if withLatestDeleted {
		query = querySearchLastDeletedVulnerabilityID(len(vulnIDs))
	} else {
		query = querySearchNotDeletedVulnerabilityID(len(vulnIDs))
	}

	rows, err := tx.Query(query, keys...)
	if err != nil {
		return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Query", err)
	}
	defer rows.Close()

	// Keys absent from this map read back as an invalid (zero) sql.NullInt64.
	found := make(map[database.VulnerabilityID]sql.NullInt64, len(vulnIDs))
	for rows.Next() {
		var (
			id  sql.NullInt64
			key database.VulnerabilityID
		)
		if err := rows.Scan(&id, &key.Name, &key.Namespace); err != nil {
			return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Scan", err)
		}
		found[key] = id
	}

	// Without this check, an iteration failure would silently report the
	// remaining keys as missing.
	if err := rows.Err(); err != nil {
		return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Rows", err)
	}

	ids := make([]sql.NullInt64, len(vulnIDs))
	for i, v := range vulnIDs {
		ids[i] = found[v]
	}

	return ids, nil
}

// FindPagedVulnerableAncestries returns the vulnerability with row ID vulnID
// together with one page of the ancestries it affects.
//
// currentToken selects the page; pagination.FirstPageToken selects the first.
// key decodes currentToken and encodes the returned tokens. Up to limit
// ancestries, taken in ascending ID order, are placed in Affected. When more
// remain, Next holds a token for the page starting at the first remaining
// ancestry; otherwise End is true. Current re-encodes the page that was
// served. limit must be positive.
func FindPagedVulnerableAncestries(tx *sql.Tx, vulnID int64, limit int, currentToken pagination.Token, key pagination.Key) (database.PagedVulnerableAncestries, error) {
	vulnPage := database.PagedVulnerableAncestries{Limit: limit}
	// A non-positive limit would either never finish paging (0) or index
	// before the start of ancestries below (< 0).
	if limit <= 0 {
		return vulnPage, fmt.Errorf("invalid page limit %d: must be positive", limit)
	}

	currentPage := page.Page{StartID: 0}
	if currentToken != pagination.FirstPageToken {
		if err := key.UnmarshalToken(currentToken, &currentPage); err != nil {
			return vulnPage, err
		}
	}

	if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
		&vulnPage.Name,
		&vulnPage.Description,
		&vulnPage.Link,
		&vulnPage.Severity,
		&vulnPage.Metadata,
		&vulnPage.Namespace.Name,
		&vulnPage.Namespace.VersionFormat,
	); err != nil {
		return vulnPage, util.HandleError("searchVulnerabilityByID", err)
	}

	// Fetch one extra row. If present, another page exists and that row's ID
	// becomes the next page's StartID (the query filters on a.id >= StartID).
	rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
	if err != nil {
		return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
	}
	defer rows.Close()

	var ancestries []affectedAncestry
	for rows.Next() {
		var ancestry affectedAncestry
		if err := rows.Scan(&ancestry.id, &ancestry.name); err != nil {
			return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
		}
		ancestries = append(ancestries, ancestry)
	}

	if err := rows.Err(); err != nil {
		return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
	}

	if len(ancestries) > limit {
		// The extra ancestry starts the next page and is not part of this one.
		vulnPage.Next, err = key.MarshalToken(page.Page{StartID: ancestries[limit].id})
		if err != nil {
			return vulnPage, err
		}
		ancestries = ancestries[:limit]
	} else {
		vulnPage.End = true
	}

	vulnPage.Affected = make(map[int]string, len(ancestries))
	for _, ancestry := range ancestries {
		vulnPage.Affected[int(ancestry.id)] = ancestry.name
	}

	vulnPage.Current, err = key.MarshalToken(currentPage)
	if err != nil {
		return vulnPage, err
	}

	return vulnPage, nil
}
